import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from sklearn.neighbors import KNeighborsClassifier

# ==========================
# 1. 数据集加载与预处理（适配BSM1_WWTP_data）
# ==========================
class BSM1Dataset(Dataset):
    """Sliding-window dataset over the BSM1 wastewater-treatment CSV.

    Each item is a (seq_len, n_features) window of standardized process
    variables paired with the integer class label of the row at `index`.
    The CSV's last column is the label; all other columns are features.
    """

    def __init__(self, csv_path, seq_len=10):
        frame = pd.read_csv(csv_path)
        self.labels = frame.iloc[:, -1].values.astype(int)
        features = frame.iloc[:, :-1].values.astype(np.float32)
        self.seq_len = seq_len
        # Scaler is fit on the whole file; any train/val split happens later.
        self.scaler = StandardScaler()
        self.data = self.scaler.fit_transform(features)

    def __getitem__(self, index):
        # NOTE(review): for index < seq_len the window looks FORWARD
        # (data[index:index+seq_len]) instead of padding, so the first few
        # samples peek at future rows — confirm this is intended.
        if index < self.seq_len:
            window = self.data[index:index + self.seq_len]
        else:
            window = self.data[index - self.seq_len:index]
        return torch.from_numpy(window).float(), self.labels[index]

    def __len__(self):
        return len(self.labels)

# ==========================
# 2. 图神经网络组件
# ==========================
class GraphConvolution(nn.Module):
    """Single dense graph-convolution layer: out = A @ (X W) + b.

    `adj` is a pre-normalized (N, N) adjacency matrix shared by every
    sample in the batch.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        self.bias = nn.Parameter(torch.FloatTensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        # Kaiming-uniform weights, zero bias.
        nn.init.kaiming_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x, adj):
        projected = x.matmul(self.weight)
        # Broadcast the shared adjacency across the batch dimension.
        return adj.unsqueeze(0).matmul(projected) + self.bias

# ==========================
# 3. 多头自注意力机制
# ==========================
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
    def forward(self, x):
        return self.attn(x, x, x)[0]

# ==========================
# 4. Transformer块
# ==========================
class TransformerBlock(nn.Module):
    """Post-LN transformer encoder block.

    Self-attention and a position-wise feed-forward network, each applied
    with dropout inside a residual connection followed by LayerNorm.
    """

    def __init__(self, embed_dim, num_heads, hidden_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        attended = self.dropout(self.attn(x))
        x = self.norm1(x + attended)
        transformed = self.dropout(self.ff(x))
        return self.norm2(x + transformed)

# ==========================
# 5. GCN-Transformer模型
# ==========================
class GCN_Transformer(nn.Module):
    """Transformer encoder followed by a two-layer GCN classifier.

    The input window is embedded and encoded per time step; the per-step
    latents are then treated as node features on a fixed random graph of
    `input_dim` nodes, graph-convolved, mean-pooled, and classified.
    """

    def __init__(self, input_dim, seq_len, embed_dim=24, num_heads=4, num_layers=2, hidden_dim=48, latent_dim=12, dropout=0.18, num_classes=11):
        super().__init__()
        self.input_dim = input_dim
        self.seq_len = seq_len
        self.latent_dim = latent_dim
        
        # Transformer encoder: per-time-step embedding + stacked blocks.
        self.embedding = nn.Linear(input_dim, embed_dim)
        self.encoder_layers = nn.Sequential(
            *[TransformerBlock(embed_dim, num_heads, hidden_dim, dropout) for _ in range(num_layers)]
        )
        self.encoder_out = nn.Linear(embed_dim, latent_dim)
        
        # Graph part. NOTE(review): the adjacency is sampled with torch.rand,
        # so each unseeded construction yields a different graph. It is stored
        # as a buffer, so it round-trips through state_dict on save/load, but
        # freshly built models are not reproducible — confirm intended.
        self.register_buffer('adj_matrix', self._create_adjacency_matrix(input_dim))
        self.gcn1 = GraphConvolution(latent_dim, latent_dim)
        self.gcn2 = GraphConvolution(latent_dim, latent_dim)
        self.gcn_activation = nn.ReLU()
        
        # Final linear classifier over the pooled node features.
        self.classifier = nn.Linear(latent_dim, num_classes)
    
    def _create_adjacency_matrix(self, num_nodes, threshold=0.4):
        """Build a random symmetric adjacency with self-loops and apply
        symmetric normalization D^{-1/2} A D^{-1/2}.

        Each off-diagonal edge is included independently with probability
        `threshold`.
        """
        adj_matrix = torch.eye(num_nodes)
        for i in range(num_nodes):
            for j in range(num_nodes):
                if i != j and torch.rand(1).item() < threshold:
                    adj_matrix[i, j] = 1.0
                    adj_matrix[j, i] = 1.0
        degree = torch.sum(adj_matrix, dim=1)
        degree_inv_sqrt = torch.pow(degree, -0.5)
        # Guard against isolated nodes (degree 0 -> inf). The self-loops from
        # torch.eye make this unreachable here; kept as a safety net.
        degree_inv_sqrt[degree_inv_sqrt == float('inf')] = 0
        degree_inv_sqrt_matrix = torch.diag(degree_inv_sqrt)
        normalized_adj = torch.mm(torch.mm(degree_inv_sqrt_matrix, adj_matrix), degree_inv_sqrt_matrix)
        return normalized_adj
    
    def forward(self, x):
        # x: [batch, seq_len, input_dim] -> logits [batch, num_classes]
        batch_size = x.size(0)
        seq_len = x.size(1)
        
        # Transformer encoding of each window.
        x_embedded = self.embedding(x)  # [batch, seq_len, embed_dim]
        encoded = self.encoder_layers(x_embedded)  # [batch, seq_len, embed_dim]
        latent = self.encoder_out(encoded)  # [batch, seq_len, latent_dim]
        
        # Map the seq_len per-step latents onto the input_dim graph nodes.
        if seq_len == self.input_dim:
            node_features = latent  # [batch, input_dim, latent_dim]
        else:
            # seq_len differs from input_dim: project the time axis onto nodes.
            if seq_len > self.input_dim:
                # Subsample input_dim evenly spaced time steps.
                idx = torch.linspace(0, seq_len-1, steps=self.input_dim).long().to(latent.device)
                node_features = latent[:, idx, :]
            else:
                # Tile the sequence until it covers input_dim nodes, then trim.
                repeat_factor = (self.input_dim + seq_len - 1) // seq_len
                node_features = latent.repeat(1, repeat_factor, 1)[:, :self.input_dim, :]
        
        # Two GCN layers over the fixed normalized adjacency (ReLU between).
        gcn_out = self.gcn_activation(self.gcn1(node_features, self.adj_matrix))
        gcn_out = self.gcn2(gcn_out, self.adj_matrix)
        
        # Mean pooling over the node dimension.
        pooled = gcn_out.mean(dim=1)  # [batch, latent_dim]
        
        # Classification logits (CrossEntropyLoss applies softmax downstream).
        out = self.classifier(pooled)
        return out

def knn_accuracy(train_embeddings, train_labels, val_embeddings, val_labels, k=5):
    """Fit a k-NN classifier on the training embeddings and return the
    (train_accuracy, val_accuracy) pair."""
    classifier = KNeighborsClassifier(n_neighbors=k)
    classifier.fit(train_embeddings, train_labels)
    return (
        classifier.score(train_embeddings, train_labels),
        classifier.score(val_embeddings, val_labels),
    )

def main():
    """Train the GCN-Transformer fault classifier on the BSM1 dataset,
    checkpointing the model with the best validation accuracy and saving
    the loss/accuracy curves to the 'model/' directory."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch_size = 512
    seq_len = 10
    # Load the training CSV (features + trailing label column).
    train_data = BSM1Dataset('BSM1_WWTP_data/train_data.csv', seq_len=seq_len)

    # Hold out 20% of the training data as a validation split.
    train_size = int(len(train_data) * 0.8)
    val_size = len(train_data) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(train_data, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    input_dim = train_data.data.shape[1]
    num_classes = len(np.unique(train_data.labels))
    model = GCN_Transformer(input_dim=input_dim, seq_len=seq_len, num_classes=num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=6e-4, weight_decay=2e-5)
    criterion = nn.CrossEntropyLoss()

    # BUGFIX: create the output directory BEFORE training. The best-model
    # checkpoint is written inside the loop; previously the directory was
    # only created after training, so the first in-loop torch.save could
    # fail with FileNotFoundError.
    os.makedirs('model', exist_ok=True)

    loss_val = []
    acc_val = []
    best_val_acc = 0.0
    for epoch in range(100):
        # --- training pass ---
        model.train()
        train_loss = 0.0
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device).long()
            out = model(x)
            loss = criterion(out, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # --- validation pass ---
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for x, y in val_loader:
                x = x.to(device)
                y = y.to(device).long()
                out = model(x)
                loss = criterion(out, y)
                val_loss += loss.item()
                _, predicted = torch.max(out, 1)  # .data access removed (deprecated idiom)
                val_total += y.size(0)
                val_correct += (predicted == y).sum().item()

        val_acc = val_correct / val_total
        print(f"Epoch {epoch}: val acc={val_acc:.4f}, loss={val_loss/len(val_loader):.4f}")

        # Checkpoint only when validation accuracy improves.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'model/model_GCN_Transformer_BSM1.pth')

        loss_val.append(val_loss/len(val_loader))
        acc_val.append(val_acc)

    # Persist the training curves.
    # BUGFIX: the best checkpoint saved above is deliberately NOT overwritten
    # here — the original code unconditionally re-saved the last-epoch weights
    # to the same path, clobbering the best-validation model.
    np.save('model/acc_val_GCN_Transformer_BSM1.npy', np.array(acc_val))
    np.save('model/loss_GCN_Transformer_BSM1.npy', np.array(loss_val))
    plt.figure()
    plt.plot(loss_val)
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.figure()
    plt.plot(acc_val)
    plt.xlabel('epoch')
    plt.ylabel('acc')
    plt.show()

# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()